package com.diven.spark.ml.learn.feature

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DataTypes
import com.diven.spark.ml.learn.core.BaseTest
import com.diven.spark.ml.learn.core.BaseSpark
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.feature.DCT
import org.apache.spark.ml.feature.ElementwiseProduct

/**
 * Companion factory for [[ElementwiseProductTest]].
 *
 * Mirrors the other feature-test objects in this package: `apply()` hands back
 * a fresh test instance typed as the common [[BaseSpark]] base so callers can
 * drive it uniformly.
 */
object ElementwiseProductTest extends BaseTest {

    /** Builds a new test instance, exposed through the shared [[BaseSpark]] interface. */
    def apply(): BaseSpark = {
        new ElementwiseProductTest()
    }

}

/**
 * Demonstrates Spark ML's `ElementwiseProduct` transformer: multiplies each
 * input vector element-wise (Hadamard product) by a fixed scaling vector.
 */
class ElementwiseProductTest extends BaseSpark {

     /**
      * Builds the sample input: three rows of (id, 4-dimensional dense vector).
      *
      * @param sparkSession session used to create the DataFrame; defaults to
      *                     the one managed by [[BaseSpark]]
      * @return DataFrame with columns "id" and "features"
      */
     override def getDataFrame(sparkSession: SparkSession = this.getSparkSession()): DataFrame = {
         sparkSession.createDataFrame(Seq(
              (0, Vectors.dense(0.0, 1.0, -2.0, 3.0)),
              (1, Vectors.dense(-1.0, 2.0, 4.0, -7.0)),
              (2, Vectors.dense(14.0, -2.0, -5.0, 1.0))
         ))
         .toDF("id", "features")
     }

    /**
     * Applies an element-wise product with a fixed weight vector to the
     * "features" column and prints the transformed result and schemas.
     *
     * @param dataFrame input containing a vector column named "features"
     */
    override def execute(dataFrame: DataFrame): Unit = {
        // Column names: input feature column and the transformed output column.
        // (vals: these are never reassigned.)
        val feature = "features"
        val featureNew = "features_elementwise"
        // Weight (scaling) vector — each input vector is multiplied
        // component-wise by this vector; note the first weight zeroes out
        // the first component of every row.
        val transformingVector = Vectors.dense(0.0, 1.0, 2.0, 3.0)
        // Configure the transformer.
        val elementwiseProduct = new ElementwiseProduct()
        .setInputCol(feature)                 // feature column to transform
        .setOutputCol(featureNew)             // name of the transformed column
        .setScalingVec(transformingVector)    // weight vector
        // Run the transformation.
        val transform = elementwiseProduct.transform(dataFrame)
        // Show up to 100 rows, truncating cell values at 100 characters.
        transform.show(100, 100)

        dataFrame.show(false)
        dataFrame.printSchema()
        transform.printSchema()
    }

}
